Two Sum BSTs [Stack]

Time: O(M+N); Space: O(M+N); medium (M and N are the numbers of nodes in the two trees)

Given two binary search trees, return True if and only if there is a node in the first tree and a node in the second tree whose values sum up to a given integer target.

Example 1:

Input: root1 = [2,1,4], root2 = [1,0,3], target = 5

Output: True

Explanation:

  • 2 and 3 sum up to 5.

Example 2:

Input: root1 = [0,-10,10], root2 = [5,1,7,0,2], target = 18

Output: False

Constraints:

  • Each tree has at most 5000 nodes.

  • -10^9 <= target, node.val <= 10^9

[20]:
class TreeNode(object):
    """A binary tree node: a value ``val`` plus optional ``left``/``right`` children."""

    def __init__(self, x):
        # Children start empty; callers attach subtrees after construction.
        self.val, self.left, self.right = x, None, None

1. Using stack [O(M+N), O(M+N)]

[21]:
class Solution1(object):
    """
    Two-pointer search: walk the first BST in ascending in-order and the
    second BST in descending in-order with explicit stacks, advancing
    whichever side moves the running sum toward ``target``.

    Time: O(M+N)
    Space: O(M+N)
    """
    def twoSumBSTs(self, root1, root2, target):
        # ``asc`` holds pending nodes of tree 1 (smallest unvisited on top);
        # ``desc`` holds pending nodes of tree 2 (largest unvisited on top).
        asc, desc = [], []
        node1, node2 = root1, root2

        while True:
            # Slide down to the next-smallest value of tree 1 ...
            while node1:
                asc.append(node1)
                node1 = node1.left
            # ... and the next-largest value of tree 2.
            while node2:
                desc.append(node2)
                node2 = node2.right

            # One traversal is exhausted: no pair remains to be tried.
            if not asc or not desc:
                return False

            total = asc[-1].val + desc[-1].val
            if total == target:
                return True
            if total < target:
                # Sum too small: consume tree 1's smallest, move to its successor.
                node1 = asc.pop().right
            else:
                # Sum too large: consume tree 2's largest, move to its predecessor.
                node2 = desc.pop().left
[22]:
s = Solution1()

# Example 1: root1 = [2,1,4], root2 = [1,0,3], target = 5 -> True (2 + 3 == 5)
root1, root2, target = TreeNode(2), TreeNode(1), 5
root1.left, root1.right = TreeNode(1), TreeNode(4)
root2.left, root2.right = TreeNode(0), TreeNode(3)
assert s.twoSumBSTs(root1, root2, target) == True

# Example 2: root1 = [0,-10,10], root2 = [5,1,7,0,2], target = 18 -> False
root1, root2, target = TreeNode(0), TreeNode(5), 18
root1.left, root1.right = TreeNode(-10), TreeNode(10)
root2.left, root2.right = TreeNode(1), TreeNode(7)
root2.left.left, root2.left.right = TreeNode(0), TreeNode(2)
assert s.twoSumBSTs(root1, root2, target) == False

2. Using stack-based in-order generators [O(M+N), O(M+N)]

[23]:
class Solution2(object):
    """
    Two-pointer search driven by two lazy in-order generators: ascending over
    the first tree and descending over the second.

    Time: O(M+N)
    Space: O(H1+H2) for the generator stacks (H = tree height), which is
           O(M+N) in the worst case of degenerate trees
    """
    def twoSumBSTs(self, root1, root2, target):
        """
        :type root1: TreeNode
        :type root2: TreeNode
        :type target: int
        :rtype: bool
        """
        def inorder_gen(root, asc=True):
            """Yield the values of the BST rooted at ``root`` in in-order.

            Values come out ascending when ``asc`` is true, descending
            otherwise. An empty tree (``root is None``) yields nothing.
            """
            stack = [(root, False)]

            while stack:
                node, is_visited = stack.pop()
                if node is None:
                    continue
                if is_visited:
                    yield node.val
                    continue
                # The stack is LIFO, so push in reverse of the visiting order:
                # the subtree visited first is pushed last.
                if asc:
                    first, last = node.left, node.right
                else:
                    first, last = node.right, node.left
                stack.append((last, False))
                stack.append((node, True))
                stack.append((first, False))

        left_gen = inorder_gen(root1, True)
        right_gen = inorder_gen(root2, False)

        # next(gen, None) yields None once a traversal is exhausted, which
        # ends the loop (node values are never None).
        left, right = next(left_gen, None), next(right_gen, None)

        while left is not None and right is not None:
            total = left + right
            if total < target:
                left = next(left_gen, None)
            elif total > target:
                right = next(right_gen, None)
            else:
                return True

        return False
[24]:
s = Solution2()

# Example 1: root1 = [2,1,4], root2 = [1,0,3], target = 5 -> True (2 + 3 == 5)
root1, root2, target = TreeNode(2), TreeNode(1), 5
root1.left, root1.right = TreeNode(1), TreeNode(4)
root2.left, root2.right = TreeNode(0), TreeNode(3)
assert s.twoSumBSTs(root1, root2, target) == True

# Example 2: root1 = [0,-10,10], root2 = [5,1,7,0,2], target = 18 -> False
root1, root2, target = TreeNode(0), TreeNode(5), 18
root1.left, root1.right = TreeNode(-10), TreeNode(10)
root2.left, root2.right = TreeNode(1), TreeNode(7)
root2.left.left, root2.left.right = TreeNode(0), TreeNode(2)
assert s.twoSumBSTs(root1, root2, target) == False

3. Recursion

[25]:
class Solution3(object):
    """
    Collect every value of each tree into a hash set, then look for a value
    ``v`` of the first tree whose complement ``target - v`` occurs in the
    second tree.

    Time: O(M+N)
    Space: O(M+N)
    """
    def twoSumBSTs(self, root1, root2, target):
        """
        :type root1: TreeNode
        :type root2: TreeNode
        :type target: int
        :rtype: bool
        """
        def collect(node, values):
            """Add every value in the subtree rooted at ``node`` to ``values``.

            An empty subtree (``node is None``) adds nothing, so empty input
            trees are handled. Recursion depth equals the subtree height, so a
            degenerate tree deeper than roughly ``sys.getrecursionlimit()``
            raises RecursionError.
            """
            if node is None:
                return
            values.add(node.val)
            collect(node.left, values)
            collect(node.right, values)

        values1, values2 = set(), set()
        collect(root1, values1)
        collect(root2, values2)

        return any(target - val in values2 for val in values1)
[26]:
s = Solution3()

# Example 1: root1 = [2,1,4], root2 = [1,0,3], target = 5 -> True (2 + 3 == 5)
root1, root2, target = TreeNode(2), TreeNode(1), 5
root1.left, root1.right = TreeNode(1), TreeNode(4)
root2.left, root2.right = TreeNode(0), TreeNode(3)
assert s.twoSumBSTs(root1, root2, target) == True

# Example 2: root1 = [0,-10,10], root2 = [5,1,7,0,2], target = 18 -> False
root1, root2, target = TreeNode(0), TreeNode(5), 18
root1.left, root1.right = TreeNode(-10), TreeNode(10)
root2.left, root2.right = TreeNode(1), TreeNode(7)
root2.left.left, root2.left.right = TreeNode(0), TreeNode(2)
assert s.twoSumBSTs(root1, root2, target) == False

4. Binary Search [O(M log N) when the second tree is balanced]

[27]:
class Solution4(object):
    """
    For each value ``a`` of the first tree, binary-search the second tree for
    ``target - a``. The two roots are compared first so that one subtree of
    each tree can be pruned.

    Time: O(M * H2), i.e. O(M log N) when the second tree is balanced
    Space: O(H1) for the explicit traversal stack over the first tree
    """
    def twoSumBSTs(self, root1, root2, target):
        """
        Return True iff some node of ``root1`` and some node of ``root2``
        have values that sum to ``target``.

        :type root1: TreeNode
        :type root2: TreeNode
        :type target: int
        :rtype: bool
        """
        if not root1 or not root2:
            return False
        total = root1.val + root2.val
        if total == target:
            return True
        if total > target:
            # Need a smaller sum, so either the first value is below root1.val
            # (root1's left subtree) or the second is below root2.val
            # (root2's left subtree).
            return (self.twoSumHelper(root1.left, root2, target)
                    or self.twoSumHelper(root1, root2.left, target))
        # Need a larger sum, so either the first value is above root1.val
        # (root1's right subtree) or the second is above root2.val
        # (root2's right subtree).
        return (self.twoSumHelper(root1.right, root2, target)
                or self.twoSumHelper(root1, root2.right, target))

    def twoSumHelper(self, root1, root2, target):
        """
        Return True iff some node in the subtree ``root1`` has its complement
        ``target - node.val`` somewhere in the BST ``root2``.

        Time: O(|root1| * height(root2)), i.e. O(MlogN) for a balanced root2
        """
        if not root1 or not root2:
            return False
        # Iterative pre-order walk, so a deep first tree cannot exhaust the
        # Python recursion limit.
        stack = [root1]
        while stack:
            node = stack.pop()
            if self.binarySearch(root2, target - node.val):
                return True
            # Push right first so the left subtree is examined first.
            if node.right:
                stack.append(node.right)
            if node.left:
                stack.append(node.left)
        return False

    def binarySearch(self, root, val):
        """Return True iff ``val`` occurs in the BST rooted at ``root``."""
        node = root
        while node:
            if node.val == val:
                return True
            node = node.right if node.val < val else node.left
        return False
[28]:
s = Solution4()

# Example 1: root1 = [2,1,4], root2 = [1,0,3], target = 5 -> True (2 + 3 == 5)
root1, root2, target = TreeNode(2), TreeNode(1), 5
root1.left, root1.right = TreeNode(1), TreeNode(4)
root2.left, root2.right = TreeNode(0), TreeNode(3)
assert s.twoSumBSTs(root1, root2, target) == True

# Example 2: root1 = [0,-10,10], root2 = [5,1,7,0,2], target = 18 -> False
root1, root2, target = TreeNode(0), TreeNode(5), 18
root1.left, root1.right = TreeNode(-10), TreeNode(10)
root2.left, root2.right = TreeNode(1), TreeNode(7)
root2.left.left, root2.left.right = TreeNode(0), TreeNode(2)
assert s.twoSumBSTs(root1, root2, target) == False